# -*- coding: utf-8 -*-
"""
Created on Mon May 13 15:52:09 2019

@author: Salem
"""

import numpy as np
import numpy.linalg as la
import numpy.random as npr


import scipy.optimize as op

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from stl import mesh

#-------------------------------------
from pathlib import WindowsPath
from pathlib import Path
import os

# Directory the script was launched from; run outputs are collected under it.
cwd = Path(os.getcwd())

# Shared root folder: the ancestor three directories above the directory that
# holds this file (Path(__file__).parents[3]).
Dropbox = Path(__file__).parents[3]

# Mesh inputs and mesh-derived outputs, all under the shared mesh folder.
mesh_dir = Dropbox.joinpath("With Arkhat", "Meshes")
fit_cubic_dir = mesh_dir.joinpath("Fits Cubic")
fit_quartic_dir = mesh_dir.joinpath("Fits Quartic")
tip_fit_cubic_dir = mesh_dir.joinpath("Tips Fits Cubic")
top_open_dir = mesh_dir.joinpath("Tops Open")
tip_open_dir = mesh_dir.joinpath("Tips Open")

# Per-run outputs, relative to the launch directory.
results_dir = cwd.joinpath("Run Results")
edge_dir = results_dir.joinpath("Top Edges A")
#-------------------------------------




# Specimen names. Each one is the base name of a mesh file ("<name>.stl") that
# the functions below load, so the spelling must match the files on disk.
beak1 = "G.FuliginosaA"
beak2 = "G.FuliginosaB"
beak3 = "G.FuliginosaC"
beak4 = "G.FuliginosaD"
beak5 = "G.FuliginosaE"
beak6 = "G.FuliginosaF"
beak7 = "G.FortisA"
beak8 = "G.FortisB"
beak9 = "G.FortisC"
beak10 = "G.FortisD"
beak11 = "G.FortisE"
beak12 = "G.MagnirostrisA"
beak13 = "G.MagnirostrisB"
beak14 = "G.MagnirostrisC"
beak15 = "G.MagnirostrisD"
beak16 = "G.MagnirostrisE"
beak17 = "G.ConirostrisA"
beak18 = "G.ConirostrisB"
beak19 = "G.ConirostrisC"
beak20 = "G.ConirostrisD"
beak21 = "G.ConirostrisE"
beak22 = "G.ConirostrisF"
beak23 = "G.ConirostrisG"
beak24 = "G.ScandensA"
beak25 = "G.ScandensB"
beak26 = "G.SeptentrionalistA"
beak27 = "G.SeptentrionalistB"
beak28 = "G.DifficilisA"
beak29 = "G.DifficilisB"
beak30 = "C.PallidusA"
beak31 = "C.PallidusB"
beak32 = "C.PallidusC"
beak33 = "C.PallidusD"
beak34 = "C.ParvulusA"
beak35 = "C.ParvulusB"
beak36 = "C.ParvulusC"
beak37 = "C.ParvulusD"
beak38 = "C.ParvulusE"
beak39 = "C.PsittaculaA"
beak40 = "C.PsittaculaB"
beak41 = "C.PsittaculaC"
beak42 = "C.PsittaculaD"
beak43 = "P2.InornataA"
beak44 = "P2.InornataB"
beak45 = "C2.OlivaceaA"
beak46 = "C2.OlivaceaB"
beak47 = "C2.OlivaceaC"
beak48 = "C2.FuscaA"
beak49 = "C2.FuscaB"
beak50 = "P.CrassirostrisA"
beak51 = "P.CrassirostrisB"
beak52 = "P.CrassirostrisC"
beak53 = "P.CrassirostrisD"
beak54 = "P.CrassirostrisE"
beak55 = "T.BicolorA"
beak56 = "T.BicolorB"
beak57 = "T.BicolorC"
beak58 = "T.BicolorD"
beak59 = "T.BicolorE"

# Every specimen, in beak1 ... beak59 order.
names = np.array([
    beak1, beak2, beak3, beak4, beak5, beak6,               # G.Fuliginosa
    beak7, beak8, beak9, beak10, beak11,                     # G.Fortis
    beak12, beak13, beak14, beak15, beak16,                  # G.Magnirostris
    beak17, beak18, beak19, beak20, beak21, beak22, beak23,  # G.Conirostris
    beak24, beak25,                                          # G.Scandens
    beak26, beak27,                                          # G.Septentrionalist
    beak28, beak29,                                          # G.Difficilis
    beak30, beak31, beak32, beak33,                          # C.Pallidus
    beak34, beak35, beak36, beak37, beak38,                  # C.Parvulus
    beak39, beak40, beak41, beak42,                          # C.Psittacula
    beak43, beak44,                                          # P2.Inornata
    beak45, beak46, beak47,                                  # C2.Olivacea
    beak48, beak49,                                          # C2.Fusca
    beak50, beak51, beak52, beak53, beak54,                  # P.Crassirostris
    beak55, beak56, beak57, beak58, beak59,                  # T.Bicolor
])



#=========================================================================================================================
# Reads an STL file and returns its connectivity: the unique (welded) vertices and, for each triangular face, the indices of its three vertices
#========================================================================================================================= 
def get_connectivity(mesh_name):
    """Load an STL file and return its welded vertex list and face indices.

    Corner coordinates are rounded to 4 decimal places before de-duplication,
    so corners that differ only beyond that precision become one vertex.

    Parameters
    ----------
    mesh_name : path-like
        STL file handed to ``mesh.Mesh.from_file``.

    Returns
    -------
    vertices : ndarray, shape (n_vertices, 3)
        Unique rounded coordinates, as returned (sorted) by ``np.unique``.
    faces : ndarray of int, shape (n_faces, 3)
        For each triangle, the row indices into ``vertices`` of its corners.
    """
    corners = mesh.Mesh.from_file(mesh_name).points.round(decimals=4)

    # One row per triangle corner: (x, y, z).
    flat_corners = corners.reshape(-1, 3)
    vertices, corner_to_vertex = np.unique(flat_corners, axis=0, return_inverse=True)

    # Consecutive triples of corners belong to the same triangle.
    faces = corner_to_vertex.reshape(-1, 3)

    return vertices, faces
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#========================================================================================================================= 



#=========================================================================================================================
# Builds an STL mesh from the given vertices and faces, then saves it to disk
#========================================================================================================================= 
def make_stl(vertices, faces, mesh_name):
    """Build an STL mesh from indexed triangles and save it to ``mesh_name``.

    Parameters
    ----------
    vertices : ndarray, shape (n_vertices, 3)
        Vertex coordinates.
    faces : ndarray of int, shape (n_faces, 3)
        Row indices into ``vertices`` for the three corners of each triangle.
    mesh_name : path-like
        Destination file, passed to ``Mesh.save``.
    """
    # Create an empty mesh with one record per triangle.
    stl_mesh = mesh.Mesh(np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype))

    # Gather every triangle's corner coordinates in one fancy-indexing step.
    # vertices[faces] has shape (n_faces, 3, 3), the same layout as
    # stl_mesh.vectors. This replaces a per-corner Python double loop.
    if faces.shape[0]:
        stl_mesh.vectors[:] = vertices[faces]

    print("\n saving to: ", mesh_name, "\n")

    # Write the mesh to file mesh_name
    stl_mesh.save(mesh_name)

    return
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#========================================================================================================================= 

#=========================================================================================================================
# Apply find_fit to every beak in mesh_names and save the stacked fit parameters
#========================================================================================================================= 
def find_all_fits(mesh_names=names, save_name = "fit_quartic_params.npy", poly_degree = 4, save_dir=fit_quartic_dir,
                  load_dir=top_open_dir):
    """Run ``find_fit`` on every mesh in ``mesh_names`` and save the stacked results.

    Row ``i`` of the table corresponds to ``mesh_names[i]``. It holds that
    mesh's polynomial coefficients followed by its six bounds
    [minX, maxX, minY, maxY, minZ, maxZ].

    Parameters
    ----------
    mesh_names : sequence of str
        Mesh base names; ``<name>.stl`` is read from ``load_dir``.
    save_name : str
        File name, inside ``results_dir``, of the saved table. It is written
        with ``np.savetxt`` (plain text) despite the ``.npy`` default suffix.
    poly_degree : int
        Total degree of the fitted polynomial surface (any non-negative int).
    save_dir, load_dir : Path
        Forwarded unchanged to ``find_fit``.

    Returns
    -------
    ndarray, shape (len(mesh_names), n_params)
        The table that was saved.
    """
    # find_fit (default spatial_dims=3) regresses z on the two coordinates
    # (x, y). A full polynomial of total degree d in two variables has
    # (d + 1)(d + 2) / 2 terms, and six bounds are appended. That gives
    # 16, 21 and 34 for d = 3, 4 and 6, and also covers every other degree
    # instead of leaving num_params unbound.
    num_params = (poly_degree + 1) * (poly_degree + 2) // 2 + 6

    # len() (not .size) so plain lists of names work as well as arrays.
    fit_params = np.zeros((len(mesh_names), num_params))

    for i, name in enumerate(mesh_names):
        fit_params[i] = find_fit(name, poly_degree, save_dir, load_dir)

    np.savetxt(results_dir / save_name, fit_params)

    return fit_params

#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#=========================================================================================================================  
    


#=========================================================================================================================
# Fit a polynomial height surface to a single beak mesh
#========================================================================================================================= 
def find_fit(mesh_name, poly_degree, save_dir, load_dir, isSave=True,
             spatial_dims=3):
    """Fit a polynomial height function to one beak mesh.

    The vertex column ``spatial_dims - 1`` is regressed on a full polynomial of
    total degree ``poly_degree`` in the columns before it. With the default
    ``spatial_dims=3`` this fits z = f(x, y).

    Parameters
    ----------
    mesh_name : str
        Mesh base name; ``<mesh_name>.stl`` is read from ``load_dir``.
    poly_degree : int
        Total degree of the polynomial.
    save_dir, load_dir : Path
        Output / input directories for the STL files.
    isSave : bool
        If True, write the fitted surface to ``save_dir / (mesh_name + '.stl')``.
        The faces are unchanged and each vertex is moved onto the fit.
    spatial_dims : int
        Number of leading vertex columns used; the last of them is the target.

    Returns
    -------
    ndarray
        Polynomial coefficients in ``PolynomialFeatures`` order, with the
        regression intercept folded into the constant-term slot. These are
        followed by [minX, maxX, minY, maxY, minZ, maxZ] of the *input* mesh.
    """
    from sklearn.preprocessing import PolynomialFeatures
    from sklearn import linear_model

    vertices, faces = get_connectivity(load_dir / (mesh_name + '.stl'))

    # Bounds must be taken before the vertices are overwritten by the fit below.
    lows = vertices.min(axis=0)
    highs = vertices.max(axis=0)
    bounds = [lows[0], highs[0], lows[1], highs[1], lows[2], highs[2]]

    target_col = spatial_dims - 1
    X = vertices[:, :target_col]
    target = vertices[:, target_col]

    poly = PolynomialFeatures(degree=poly_degree)
    # One design matrix, reused for both fitting and prediction.
    X_ = poly.fit_transform(X)

    clf = linear_model.LinearRegression()
    clf.fit(X_, target)

    # Move every vertex onto the fitted surface. The write goes back into the
    # same column that was used as the regression target (previously
    # hard-coded as column 2).
    vertices[:, target_col] = clf.predict(X_)

    if isSave:
        make_stl(vertices, faces, save_dir / (mesh_name + '.stl'))

    # get_feature_names() was removed in scikit-learn 1.2. Prefer
    # get_feature_names_out() and fall back for old versions.
    feature_names = getattr(poly, "get_feature_names_out", None) or poly.get_feature_names
    print(feature_names())

    # LinearRegression fits its own intercept, so the coefficient of the
    # constant "1" column from PolynomialFeatures comes out effectively zero.
    # The constant term is stored in intercept_ instead. Folding it into that
    # slot makes the returned coefficients alone reproduce the fitted surface.
    coefs = np.array(clf.coef_, dtype=float, copy=True)
    coefs[0] += clf.intercept_

    return np.hstack((coefs, bounds))
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#=========================================================================================================================  
    

#=========================================================================================================================
# Find the smallest-y and largest-y edge vertices in equal-width bins along x
#========================================================================================================================= 
def find_edge_verts(mesh_name, save_dir=edge_dir, load_dir=top_open_dir, isSave=True, num_intervals=50):
    """Sample the smallest-y and largest-y edges of a mesh along the x axis.

    The x range [minX, maxX] is split into ``num_intervals`` equal bins. In
    each bin, the vertex with the smallest y and the vertex with the largest y
    are taken as points of the "min" and "max" edge respectively.

    Bins are half-open, [x_i, x_{i+1}), except the last, which is closed. So
    every vertex falls in exactly one bin, including vertices at minX, at maxX
    and on interior bin boundaries. Bins that contain no vertex are skipped
    (for both edges alike). Each edge therefore has at most ``num_intervals``
    points, and the two edges always have the same length.

    Parameters
    ----------
    mesh_name : str
        Mesh base name; ``<mesh_name>.stl`` is read from ``load_dir``.
    save_dir : Path
        Directory for the saved edge points when ``isSave`` is True.
    load_dir : Path
        Directory containing the input STL.
    isSave : bool
        If True, write the min-edge rows stacked above the max-edge rows.
        The file is ``save_dir / (mesh_name + '.npy')``, written with
        ``np.savetxt``.
    num_intervals : int
        Number of x bins.

    Returns
    -------
    min_edge, max_edge : ndarray, shape (k, 3)
        The selected vertices, one per non-empty bin, ordered by increasing x bin.
    """
    vertices, _ = get_connectivity(load_dir / (mesh_name + '.stl'))

    xs = vertices[:, 0]
    ys = vertices[:, 1]

    # Bin edges along x; linspace pins the first/last edge exactly to minX/maxX.
    x_values = np.linspace(xs.min(), xs.max(), num_intervals + 1)

    # NOTE(review): dumps every vertex to a fixed file on every call,
    # regardless of isSave. Kept in case other code reads it; consider removing.
    np.savetxt(results_dir / "verts.npy", vertices)

    min_edge = []
    max_edge = []
    last = num_intervals - 1

    for i in range(num_intervals):
        lo, hi = x_values[i], x_values[i + 1]
        # Close the final bin so the vertex at maxX is not lost.
        below_hi = (xs <= hi) if i == last else (xs < hi)
        in_bin = (xs >= lo) & below_hi
        if not in_bin.any():
            continue

        segment = vertices[in_bin]
        seg_y = ys[in_bin]
        # argmin/argmax return the first occurrence, so ties go to the vertex
        # that appears earliest in `vertices`.
        min_edge.append(segment[np.argmin(seg_y)])
        max_edge.append(segment[np.argmax(seg_y)])

    if isSave:
        np.savetxt(save_dir / (mesh_name + '.npy'), np.vstack((min_edge, max_edge)))

    return np.array(min_edge), np.array(max_edge)
#=========================================================================================================================
#xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
#=========================================================================================================================  













































